package utils
{
  // UART : Universal Asynchronous Receiver/Transmitter
  import chisel3._
  import chisel3.util._
  import bus._

  /**
    * Frame parameters shared by the UART bundles and modules in this file.
    *
    * A frame consists of `startbits` start bit(s), `databits` payload bits and
    * `stopbits` stop bit(s).
    */
  sealed trait uartconstants{
    // Not referenced by the definitions visible in this file.
    val RegWidth = 32
    val startbits = 1
    val databits = 8
    val stopbits = 2
    // Width of a counter that is loaded with the full frame length and counts
    // down to zero. It has to represent every value in 0..frameLength, hence
    // the `+ 1`: log2Ceil(n) alone is one bit short whenever n is a power of two.
    val bitsRegWidth = log2Ceil(startbits + databits + stopbits + 1)
  }

  /** Base class for UART bundles; mixes in `uartconstants` so subclasses can use the frame constants directly. */
  sealed abstract class uartBundle extends Bundle with uartconstants
  /** Base class for UART modules; mixes in `uartconstants` so subclasses can use the frame constants directly. */
  sealed abstract class uartModule extends Module with uartconstants

  /**
    * Payload of the UART handshake channels: a single word of `databits` bits.
    *
    * Instances are wrapped in `DecoupledIO` by the modules below, which adds
    * the ready/valid handshake around `data`.
    */
  class UartIO extends uartBundle {
    val data = Output(UInt(databits.W))
  }

  /**
    * UART transmitter.
    *
    * Accepts one data word at a time on `io.channel` and shifts it out on
    * `io.txd`, least-significant bit first, framed as start bit(s) (0), the
    * data bits, then stop bit(s) (1). The line idles at 1.
    *
    * @param frequency clock frequency, in the same unit as `baudRate` (cycles per second)
    * @param baudRate  bit rate of the serial line
    */
  class Tx(frequency: Int, baudRate: Int) extends uartModule {
    require(baudRate > 0, s"baudRate must be positive, got $baudRate")

    val io = IO(new Bundle {
      val txd = Output(Bits(1.W))
      val channel = Flipped(DecoupledIO(new UartIO()))
    })

    // Total number of bits shifted out per frame.
    private val frameBits = startbits + databits + stopbits
    // Width of the baud-rate divider counter.
    private val cntWidth = 20
    // Clock cycles per bit (rounded to nearest), minus one because the counter
    // runs down to and including zero. Computed in BigInt to avoid Int overflow.
    private val bitCycles = (BigInt(frequency) + baudRate / 2) / baudRate - 1

    require(bitCycles >= 0,
      s"frequency ($frequency) is too low for baudRate ($baudRate)")
    require(bitCycles.bitLength <= cntWidth,
      s"frequency/baudRate ratio does not fit the $cntWidth-bit baud counter")
    require(BigInt(frameBits).bitLength <= bitsRegWidth,
      s"bitsRegWidth ($bitsRegWidth) cannot hold the frame length ($frameBits)")

    val BIT_CNT = bitCycles.U

    // All-ones pattern: what the shift register holds while the line is idle.
    private val idleFrame = ((BigInt(1) << frameBits) - 1).U(frameBits.W)

    val shiftReg = RegInit(idleFrame)
    val cntReg = RegInit(0.U(cntWidth.W))   // to generate correct baud rate
    val bitsReg = RegInit(0.U(bitsRegWidth.W))

    // A new word is accepted only on a bit boundary with no frame in flight.
    io.channel.ready := (cntReg === 0.U) && (bitsReg === 0.U)
    io.txd := shiftReg(0)

    when(cntReg === 0.U) {
      cntReg := BIT_CNT
      when(bitsReg =/= 0.U) {
        // Shift right by one, filling with 1 so the line ends up idle-high.
        shiftReg := Cat(1.U(1.W), shiftReg(frameBits - 1, 1))
        bitsReg := bitsReg - 1.U
      }.otherwise {
        when(io.channel.valid) {
          // stop bits (1s) , data , start bit(s) (0s); bit 0 goes out first
          shiftReg := Cat(Fill(stopbits, 1.U(1.W)), io.channel.bits.data, 0.U(startbits.W))
          bitsReg := frameBits.U
        }.otherwise {
          shiftReg := idleFrame
        }
      }
    }.otherwise {
      cntReg := cntReg - 1.U
    }
  }

  // single-byte data buffer
  /**
    * One-entry buffer between two ready/valid channels.
    *
    * While empty it accepts a word from `io.in`; while full it offers that word
    * on `io.out` and refuses further input. Because `io.in.ready` is only high
    * in the empty state, a new word can be taken no earlier than the cycle
    * after the previous one was handed over.
    */
  class Buffer extends uartModule {
    val io = IO(new Bundle {
      val in = Flipped(DecoupledIO(new UartIO()))
      val out = DecoupledIO(new UartIO())
    })
    val empty :: full :: Nil = Enum(2)
    val stateReg = RegInit(empty)
    // Same width as UartIO.data, so the buffer follows `databits`.
    val dataReg = RegInit(0.U(databits.W))

    io.in.ready := stateReg === empty
    io.out.valid := stateReg === full
    io.out.bits.data := dataReg

    when(stateReg === empty) {  // empty: wait for a producer
      when(io.in.valid) {
        dataReg := io.in.bits.data
        stateReg := full
      }
    }.otherwise { // full: wait for the consumer
      when(io.out.ready) {
        stateReg := empty
      }
    }
  }

  /**
    * UART receiver.
    *
    * Watches `io.rxd` for a low level (start bit), waits one and a half bit
    * times, then samples `databits` bits one bit time apart, shifting them in
    * so that the first received bit ends up in bit 0. After the last bit is
    * shifted in, the word is presented on `io.channel` with `valid` held high
    * until `ready` is seen.
    *
    * NOTE(review): the shift register keeps being updated by a following frame
    * even while `valid` is still high, so the consumer is expected to take the
    * word before the next frame's bits are sampled.
    *
    * @param frequency clock frequency, in the same unit as `baudRate` (cycles per second)
    * @param baudRate  bit rate of the serial line
    */
  class Rx(frequency: Int, baudRate: Int) extends uartModule {
    require(baudRate > 0, s"baudRate must be positive, got $baudRate")

    val io = IO(new Bundle {
      val rxd = Input(Bits(1.W))
      val channel = DecoupledIO(new UartIO())
    })

    // Width of the baud-rate divider counter.
    private val cntWidth = 20
    // Divider values are computed in BigInt: `3 * frequency` overflows Int for
    // clocks above roughly 715 MHz.
    private val bitCycles = (BigInt(frequency) + baudRate / 2) / baudRate - 1
    private val startCycles = (BigInt(3) * frequency / 2 + baudRate / 2) / baudRate - 1

    require(bitCycles >= 0,
      s"frequency ($frequency) is too low for baudRate ($baudRate)")
    require(startCycles.bitLength <= cntWidth,
      s"frequency/baudRate ratio does not fit the $cntWidth-bit baud counter")

    val BIT_CNT = bitCycles.U
    val START_CNT = startCycles.U

    // Sync in the asynchronous RX data
    // Reset to 1 to not start reading after a reset
    val rxReg = RegNext(RegNext(io.rxd, 1.U), 1.U)
    val shiftReg = RegInit(0.U(databits.W))
    val cntReg = RegInit(0.U(cntWidth.W))
    // Counts the data bits still to be sampled; must hold 0..databits.
    val bitsReg = RegInit(0.U(log2Ceil(databits + 1).W))
    val valReg = RegInit(false.B)

    when(cntReg =/= 0.U) {
      // Still waiting for the next sample point.
      cntReg := cntReg - 1.U
    }.elsewhen(bitsReg =/= 0.U) {
      // Sample point: shift the synchronized line level in from the top.
      cntReg := BIT_CNT
      shiftReg := Cat(rxReg, shiftReg >> 1)
      bitsReg := bitsReg - 1.U
      // the last shifted in
      when(bitsReg === 1.U) {
        valReg := true.B
      }
    }.elsewhen(rxReg === 0.U) {
      // wait 1.5 bits after falling edge of start
      cntReg := START_CNT
      bitsReg := databits.U
    }

    // Placed after the block above so that this clear takes priority if both
    // assignments to valReg are active in the same cycle (last connect wins).
    when(valReg && io.channel.ready) {
      valReg := false.B
    }

    io.channel.bits.data := shiftReg
    io.channel.valid := valReg
  }

  /**
    * Transmitter with a one-entry `Buffer` in front of it.
    *
    * Words offered on `io.channel` go into the buffer, the buffer feeds the
    * `Tx` instance, and the transmitter's serial output is exposed as `io.txd`.
    *
    * @param frequency clock frequency passed on to `Tx`
    * @param baudRate  baud rate passed on to `Tx`
    */
  class BufferedTx(frequency: Int, baudRate: Int) extends Module {
    val io = IO(new Bundle {
      val txd = Output(UInt(1.W))
      val channel = Flipped(DecoupledIO(new UartIO()))
    })
    val tx = Module(new Tx(frequency, baudRate))
    val buf = Module(new Buffer())

    // Serial line comes straight from the transmitter.
    io.txd := tx.io.txd

    // Handshake chain: io.channel -> buf -> tx
    buf.io.in <> io.channel
    tx.io.channel <> buf.io.out
  }
  /**
    * Demo source that pushes a fixed message through a `BufferedTx` once.
    *
    * A character index starts at 0 after reset and advances each time the
    * transmitter is ready; when it reaches the message length, `valid` drops
    * and nothing more is sent.
    *
    * @param frequency clock frequency passed on to `BufferedTx`
    * @param baudRate  baud rate passed on to `BufferedTx`
    */
  class Sender(frequency: Int, baudRate: Int) extends Module {
    val io = IO(new Bundle {
      val txd = Output(UInt(1.W))
    })

    val tx = Module(new BufferedTx(frequency, baudRate))

    // Message characters as a hardware lookup table.
    val msg = "Hello World!"
    val text = VecInit(msg.map(_.U))
    val len = msg.length.U

    // Index of the next character to send.
    val cntReg = RegInit(0.U(8.W))
    // True while there are characters left.
    private val pending = cntReg =/= len

    io.txd := tx.io.txd
    tx.io.channel.valid := pending
    tx.io.channel.bits.data := text(cntReg)

    // Advance once the current character has been accepted.
    when(pending && tx.io.channel.ready) {
      cntReg := cntReg + 1.U
    }
  }
  /**
    * Loopback: every word received by `Rx` is handed to a `BufferedTx` and
    * sent back out.
    *
    * @param frequency clock frequency passed on to both sub-modules
    * @param baudRate  baud rate passed on to both sub-modules
    */
  class Echo(frequency: Int, baudRate: Int) extends Module {
    val io = IO(new Bundle {
      val txd = Output(UInt(1.W))
      val rxd = Input(UInt(1.W))
    })
    // Disabled alternative kept from the original, a registered wire loopback:
    // io.txd := RegNext(io.rxd)
    val tx = Module(new BufferedTx(frequency, baudRate))
    val rx = Module(new Rx(frequency, baudRate))

    // Serial in -> receiver -> (handshake) -> transmitter -> serial out
    rx.io.rxd := io.rxd
    tx.io.channel <> rx.io.channel
    io.txd := tx.io.txd
  }

  /**
    * Top-level demo wrapper that instantiates either the fixed-message
    * `Sender` or the `Echo` loopback.
    *
    * The choice is made at elaboration time, so only the selected sub-module
    * is generated. In sender mode `io.rxd` is left unconnected.
    *
    * @param frequency clock frequency passed on to the selected sub-module
    * @param baudRate  baud rate passed on to the selected sub-module
    * @param doSender  `true` (default) builds the `Sender`, `false` builds the `Echo`
    */
  class UartMain(frequency: Int, baudRate: Int, val doSender: Boolean = true) extends Module {
    val io = IO(new Bundle {
      val rxd = Input(UInt(1.W))
      val txd = Output(UInt(1.W))
    })

    if (doSender) {
      val s = Module(new Sender(frequency, baudRate))
      io.txd := s.io.txd
    } else {
      val e = Module(new Echo(frequency, baudRate))
      e.io.rxd := io.rxd
      io.txd := e.io.txd
    }

  }
}


